package top.zpengblog.db2document.jdbc;

import com.alibaba.druid.pool.DruidPooledConnection;
import com.mysql.jdbc.MySQLConnection;
import org.apache.commons.lang3.StringUtils;
import top.zpengblog.db2document.constant.ColumnConstant;
import top.zpengblog.db2document.model.ColumnModel;
import top.zpengblog.db2document.model.PrimaryKeyModel;
import top.zpengblog.db2document.model.TableModel;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Metadata collector for MySQL.
 *
 * <p>Tables, columns and primary keys are read for the connection's current catalog
 * (see {@link #getCataLog(Connection)}). Column type and length reported by
 * {@code DatabaseMetaData} are overwritten with values read from
 * {@code INFORMATION_SCHEMA.COLUMNS}.
 *
 * @author ZhiPeng.Lin (codingjava@qq.com)
 * @version 1.0.0
 * @date 2022/1/28 21:52
 */
public class MysqlMetadataCollector extends AbstractMetadataCollector {

    /**
     * Fetch-size hint for the {@code INFORMATION_SCHEMA.COLUMNS} query. It is only a hint;
     * the driver may ignore it.
     */
    private static final int COLUMN_FETCH_SIZE = 4284;

    /**
     * Primary-key columns of every table in one schema, ordered by table and key position.
     * The single parameter is the schema (catalog) name.
     */
    private static final String PRIMARY_KEY_SQL =
            "SELECT TABLE_SCHEMA AS TABLE_CAT, NULL AS TABLE_SCHEM, TABLE_NAME, COLUMN_NAME, "
                    + "SEQ_IN_INDEX AS KEY_SEQ, 'PRIMARY' AS PK_NAME "
                    + "FROM INFORMATION_SCHEMA.STATISTICS "
                    + "WHERE TABLE_SCHEMA = ? AND INDEX_NAME = 'PRIMARY' "
                    + "ORDER BY TABLE_SCHEMA, TABLE_NAME, INDEX_NAME, SEQ_IN_INDEX";

    /**
     * Type information for every column in one schema. The single parameter is the schema
     * (catalog) name.
     *
     * <p>COLUMN_LENGTH is the text between the first '(' and the last ')' of COLUMN_TYPE,
     * e.g. {@code varchar(64)} -> {@code 64}, {@code decimal(10,2)} -> {@code 10,2},
     * {@code int(11) unsigned} -> {@code 11} (trailing modifiers are excluded). It is NULL
     * when COLUMN_TYPE has no '('.
     */
    private static final String ALL_COLUMNS_SQL =
            "SELECT A.TABLE_NAME, A.COLUMN_NAME, A.COLUMN_TYPE, "
                    + "CASE WHEN LOCATE('(', A.COLUMN_TYPE) > 0 "
                    + "THEN SUBSTRING(A.COLUMN_TYPE, LOCATE('(', A.COLUMN_TYPE) + 1, "
                    + "CHAR_LENGTH(A.COLUMN_TYPE) - LOCATE('(', A.COLUMN_TYPE) "
                    + "- LOCATE(')', REVERSE(A.COLUMN_TYPE))) "
                    + "ELSE NULL END COLUMN_LENGTH "
                    + "FROM INFORMATION_SCHEMA.COLUMNS A WHERE A.TABLE_SCHEMA = ?";

    /**
     * Lists the base tables ({@code TABLE} type) of the current catalog.
     *
     * @param connection open connection to the MySQL database
     * @return the tables as converted by {@code handleTable}
     * @throws SQLException if reading the metadata fails
     */
    @Override
    public List<TableModel> getTables(Connection connection) throws SQLException {
        enableTableRemarks(connection);
        try (ResultSet resultSet = connection.getMetaData()
                .getTables(getCataLog(connection), "%", "%", new String[]{"TABLE"})) {
            return handleTable(resultSet);
        }
    }

    /**
     * Enables retrieval of table comments (remarks) on a Druid-pooled MySQL connection by
     * switching the physical MySQL connection to {@code useInformationSchema=true}.
     * Other connection types are left untouched.
     *
     * <p>NOTE(review): this mutates the underlying physical connection, so the setting
     * persists after the connection is returned to the pool.
     */
    private void enableTableRemarks(Connection connection) throws SQLException {
        if (connection instanceof DruidPooledConnection) {
            DruidPooledConnection druidPooledConnection = (DruidPooledConnection) connection;
            if (druidPooledConnection.getConnection() instanceof MySQLConnection) {
                MySQLConnection mysqlConnection = (MySQLConnection) druidPooledConnection.getConnection();
                mysqlConnection.setUseInformationSchema(true);
            }
        }
    }

    /**
     * Returns the connection's current catalog, or {@code null} when it is blank.
     *
     * @param connection open connection
     * @return the catalog name, or {@code null} if it is null, empty or whitespace only
     * @throws SQLException if the catalog cannot be read
     */
    @Override
    public String getCataLog(Connection connection) throws SQLException {
        String catalog = connection.getCatalog();
        if (StringUtils.isBlank(catalog)) {
            return null;
        }
        return catalog;
    }

    /**
     * Lists all columns of the current catalog.
     *
     * <p>Type and length from {@code DatabaseMetaData} are overwritten with the values read
     * from {@code INFORMATION_SCHEMA.COLUMNS}, matched on (table name, column name).
     *
     * @param connection open connection to the MySQL database
     * @return the columns as converted by {@code handleColumn}, with type/length overridden
     * @throws SQLException if reading the metadata fails
     */
    @Override
    public List<ColumnModel> getColumns(Connection connection) throws SQLException {
        List<ColumnModel> columnModelList;
        try (ResultSet resultSet = connection.getMetaData()
                .getColumns(getCataLog(connection), "%", "%", null)) {
            columnModelList = handleColumn(resultSet);
        }

        // Index by (table, column) so each column is looked up once instead of rescanning the
        // whole list. Arrays.asList keys compare element-wise with String.equals (null-safe,
        // case-sensitive). On duplicate keys the later row wins.
        Map<List<String>, ColumnModel> typeInfoByColumn = new HashMap<>();
        for (ColumnModel typeInfo : getAllColumn(connection)) {
            typeInfoByColumn.put(
                    Arrays.asList(typeInfo.getTableName(), typeInfo.getColumnName()), typeInfo);
        }

        for (ColumnModel item : columnModelList) {
            ColumnModel typeInfo =
                    typeInfoByColumn.get(Arrays.asList(item.getTableName(), item.getColumnName()));
            if (typeInfo != null) {
                item.setColumnType(typeInfo.getColumnType());
                item.setColumnLength(typeInfo.getColumnLength());
            }
        }
        return columnModelList;
    }

    /**
     * Lists the primary-key columns of all tables in the current catalog.
     *
     * <p>Querying primary keys table by table in a loop is slow, so a single
     * {@code INFORMATION_SCHEMA.STATISTICS} query fetches them for the whole schema.
     *
     * @param connection open connection to the MySQL database
     * @return the primary keys as converted by {@code handlePrimary}
     * @throws RuntimeException wrapping the original failure if the query or conversion fails
     */
    @Override
    public List<PrimaryKeyModel> getPrimaries(Connection connection) throws SQLException {
        try (PreparedStatement statement = connection.prepareStatement(PRIMARY_KEY_SQL)) {
            // Bind the schema instead of formatting it into the SQL text.
            statement.setString(1, getCataLog(connection));
            try (ResultSet resultSet = statement.executeQuery()) {
                return handlePrimary(resultSet);
            }
        } catch (Exception e) {
            throw new RuntimeException("获取主键信息失败！", e);
        }
    }

    /**
     * Reads table name, column name, COLUMN_TYPE and the derived COLUMN_LENGTH for every
     * column in the current catalog from {@code INFORMATION_SCHEMA.COLUMNS}.
     *
     * @param connection open connection to the MySQL database
     * @return one {@link ColumnModel} per row; only table, column, type and length are set
     * @throws SQLException if the query fails
     */
    private List<ColumnModel> getAllColumn(Connection connection) throws SQLException {
        try (PreparedStatement statement = connection.prepareStatement(ALL_COLUMNS_SQL)) {
            statement.setString(1, getCataLog(connection));
            // Set on the statement so the hint applies to the result set it produces.
            if (statement.getFetchSize() < COLUMN_FETCH_SIZE) {
                statement.setFetchSize(COLUMN_FETCH_SIZE);
            }
            List<ColumnModel> columnModelList = new ArrayList<>();
            try (ResultSet resultSet = statement.executeQuery()) {
                while (resultSet.next()) {
                    ColumnModel columnModel = new ColumnModel();
                    columnModel.setColumnType(resultSet.getString(ColumnConstant.COLUMN_TYPE));
                    columnModel.setColumnLength(resultSet.getString(ColumnConstant.COLUMN_LENGTH));
                    columnModel.setColumnName(resultSet.getString(ColumnConstant.COLUMN_NAME));
                    columnModel.setTableName(resultSet.getString(ColumnConstant.TABLE_NAME));
                    columnModelList.add(columnModel);
                }
            }
            return columnModelList;
        }
    }
}
